{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "segmentation.ipynb",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.1"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cZCM65CBt1CJ"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "JOgMcEajtkmg",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rCSP-dbMw88x"
      },
      "source": [
        "# 图像分割"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NEWs8JXRuGex"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/images/segmentation\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    在 tensorFlow.google.cn 上查看</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/segmentation.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    在 Google Colab 中运行</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/segmentation.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    在 GitHub 上查看源代码</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />下载 notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "sMP7mglMuGT2"
      },
      "source": [
        "这篇教程将重点讨论图像分割任务，使用的是改进版的 [U-Net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/)。\n",
        "\n",
        "## 什么是图像分割？\n",
        "目前你已经了解在图像分类中，神经网络的任务是给每张输入图像分配一个标签或者类别。但是，有时你想知道一个物体在一张图像中的位置、这个物体的形状、以及哪个像素属于哪个物体等等。这种情况下你会希望分割图像，也就是给图像中的每个像素各分配一个标签。因此，图像分割的任务是训练一个神经网络来输出该图像对每一个像素的掩码。这对从更底层，即像素层级，来理解图像很有帮助。图像分割在例如医疗图像、自动驾驶车辆以及卫星图像等领域有很多应用。\n",
        "\n",
        "本教程将使用的数据集是 [Oxford-IIIT Pet 数据集](https://www.robots.ox.ac.uk/~vgg/data/pets/)，由 Parkhi *et al.* 创建。该数据集由图像、图像所对应的标签、以及对像素逐一标记的掩码组成。掩码其实就是给每个像素的标签。每个像素分别属于以下三个类别中的一个：\n",
        "*   类别 1：像素是宠物的一部分。\n",
        "*   类别 2：像素是宠物的轮廓。\n",
        "*   类别 3：以上都不是/外围像素。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "MQmKthrSBCld",
        "colab": {}
      },
      "source": [
        "!pip install git+https://github.com/tensorflow/examples.git"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "YQX7R4bhZy5h",
        "colab": {}
      },
      "source": [
        "try:\n",
        "  # %tensorflow_version 只存在于 Colab.\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass\n",
        "import tensorflow as tf"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "g87--n2AtyO_",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "from tensorflow_examples.models.pix2pix import pix2pix\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "tfds.disable_progress_bar()\n",
        "\n",
        "from IPython.display import clear_output\n",
        "import matplotlib.pyplot as plt"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "oWe0_rQM4JbC"
      },
      "source": [
        "## 下载 Oxford-IIIT Pets 数据集\n",
        "\n",
        "这个数据集已经集成在 Tensorflow datasets 中，只需下载即可。图像分割掩码在版本 3.0.0 中才被加入，因此我们特别选用这个版本。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "40ITeStwDwZb",
        "colab": {}
      },
      "source": [
        "dataset, info = tfds.load('oxford_iiit_pet:3.0.0', with_info=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rJcVdj_U4vzf"
      },
      "source": [
        "下面的代码进行了一个简单的图像翻转扩充。然后，将图像标准化到 [0,1]。最后，如上文提到的，像素点在图像分割掩码中被标记为 {1, 2, 3} 中的一个。为了方便起见，我们将分割掩码都减 1，得到了以下的标签：{0, 1, 2}。 "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "FD60EbcAQqov",
        "colab": {}
      },
      "source": [
        "def normalize(input_image, input_mask):\n",
        "  input_image = tf.cast(input_image, tf.float32)/255.0\n",
        "  input_mask -= 1\n",
        "  return input_image, input_mask"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2NPlCnBXQwb1",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def load_image_train(datapoint):\n",
        "  input_image = tf.image.resize(datapoint['image'], (128, 128))\n",
        "  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))\n",
        "\n",
        "  if tf.random.uniform(()) > 0.5:\n",
        "    input_image = tf.image.flip_left_right(input_image)\n",
        "    input_mask = tf.image.flip_left_right(input_mask)\n",
        "\n",
        "  input_image, input_mask = normalize(input_image, input_mask)\n",
        "\n",
        "  return input_image, input_mask"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Zf0S67hJRp3D",
        "colab": {}
      },
      "source": [
        "def load_image_test(datapoint):\n",
        "  input_image = tf.image.resize(datapoint['image'], (128, 128))\n",
        "  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))\n",
        "\n",
        "  input_image, input_mask = normalize(input_image, input_mask)\n",
        "\n",
        "  return input_image, input_mask"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "65-qHTjX5VZh"
      },
      "source": [
        "数据集已经包含了所需的测试集和训练集划分，所以我们也延续使用相同的划分。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "yHwj2-8SaQli",
        "colab": {}
      },
      "source": [
        "TRAIN_LENGTH = info.splits['train'].num_examples\n",
        "BATCH_SIZE = 64\n",
        "BUFFER_SIZE = 1000\n",
        "STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "39fYScNz9lmo",
        "colab": {}
      },
      "source": [
        "train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "test = dataset['test'].map(load_image_test)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "DeFwFDN6EVoI",
        "colab": {}
      },
      "source": [
        "train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()\n",
        "train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
        "test_dataset = test.batch(BATCH_SIZE)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Xa3gMAE_9qNa"
      },
      "source": [
        "我们来看一下数据集中的一例图像以及它所对应的掩码。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "3N2RPAAW9q4W",
        "colab": {}
      },
      "source": [
        "def display(display_list):\n",
        "  plt.figure(figsize=(15, 15))\n",
        "\n",
        "  title = ['Input Image', 'True Mask', 'Predicted Mask']\n",
        "\n",
        "  for i in range(len(display_list)):\n",
        "    plt.subplot(1, len(display_list), i+1)\n",
        "    plt.title(title[i])\n",
        "    plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))\n",
        "    plt.axis('off')\n",
        "  plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "a6u_Rblkteqb",
        "colab": {}
      },
      "source": [
        "for image, mask in train.take(1):\n",
        "  sample_image, sample_mask = image, mask\n",
        "display([sample_image, sample_mask])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FAOe93FRMk3w"
      },
      "source": [
        "## 定义模型\n",
        "这里用到的模型是一个改版的 U-Net。U-Net 由一个编码器（下采样器（downsampler））和一个解码器（上采样器（upsampler））组成。为了学习到鲁棒的特征，同时减少可训练参数的数量，这里可以使用一个预训练模型作为编码器。因此，这项任务中的编码器将使用一个预训练的 MobileNetV2 模型，它的中间输出值将被使用。解码器将使用在 TensorFlow Examples 中的 [Pix2pix tutorial](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) 里实施过的升频取样模块。  \n",
        "\n",
        "输出信道数量为 3 是因为每个像素有三种可能的标签。把这想象成一个多类别分类，每个像素都将被分到三个类别当中。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "c6iB4iMvMkX9",
        "colab": {}
      },
      "source": [
        "OUTPUT_CHANNELS = 3"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "W4mQle3lthit"
      },
      "source": [
        "如之前提到的，编码器是一个预训练的 MobileNetV2 模型，它在 [tf.keras.applications](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/applications) 中已被准备好并可以直接使用。编码器中包含模型中间层的一些特定输出。注意编码器在模型的训练过程中是不会被训练的。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "liCeLH0ctjq7",
        "colab": {}
      },
      "source": [
        "base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)\n",
        "\n",
        "# 使用这些层的激活设置\n",
        "layer_names = [\n",
        "    'block_1_expand_relu',   # 64x64\n",
        "    'block_3_expand_relu',   # 32x32\n",
        "    'block_6_expand_relu',   # 16x16\n",
        "    'block_13_expand_relu',  # 8x8\n",
        "    'block_16_project',      # 4x4\n",
        "]\n",
        "layers = [base_model.get_layer(name).output for name in layer_names]\n",
        "\n",
        "# 创建特征提取模型\n",
        "down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)\n",
        "\n",
        "down_stack.trainable = False"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "KPw8Lzra5_T9"
      },
      "source": [
        "解码器/升频取样器是简单的一系列升频取样模块，在 TensorFlow examples 中曾被实施过。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "p0ZbfywEbZpJ",
        "colab": {}
      },
      "source": [
        "up_stack = [\n",
        "    pix2pix.upsample(512, 3),  # 4x4 -> 8x8\n",
        "    pix2pix.upsample(256, 3),  # 8x8 -> 16x16\n",
        "    pix2pix.upsample(128, 3),  # 16x16 -> 32x32\n",
        "    pix2pix.upsample(64, 3),   # 32x32 -> 64x64\n",
        "]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "45HByxpVtrPF",
        "colab": {}
      },
      "source": [
        "def unet_model(output_channels):\n",
        "\n",
        "  # 这是模型的最后一层\n",
        "  last = tf.keras.layers.Conv2DTranspose(\n",
        "      output_channels, 3, strides=2,\n",
        "      padding='same', activation='softmax')  #64x64 -> 128x128\n",
        "\n",
        "  inputs = tf.keras.layers.Input(shape=[128, 128, 3])\n",
        "  x = inputs\n",
        "\n",
        "  # 在模型中降频取样\n",
        "  skips = down_stack(x)\n",
        "  x = skips[-1]\n",
        "  skips = reversed(skips[:-1])\n",
        "\n",
        "  # 升频取样然后建立跳跃连接\n",
        "  for up, skip in zip(up_stack, skips):\n",
        "    x = up(x)\n",
        "    concat = tf.keras.layers.Concatenate()\n",
        "    x = concat([x, skip])\n",
        "\n",
        "  x = last(x)\n",
        "\n",
        "  return tf.keras.Model(inputs=inputs, outputs=x)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "j0DGH_4T0VYn"
      },
      "source": [
        "## 训练模型\n",
        "现在，要做的只剩下编译和训练模型了。这里用到的损失函数是 losses.sparse_categorical_crossentropy。使用这个损失函数是因为神经网络试图给每一个像素分配一个标签，和多类别预测是一样的。在正确的分割掩码中，每个像素点的值是 {0,1,2} 中的一个。同时神经网络也输出三个信道。本质上，每个信道都在尝试学习预测一个类别，而 losses.sparse_categorical_crossentropy 正是这一情形下推荐使用的损失函数。根据神经网络的输出值，分配给每个像素的标签为输出值最高的信道所表示的那一类。这就是 create_mask 函数所做的工作。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "6he36HK5uKAc",
        "colab": {}
      },
      "source": [
        "model = unet_model(OUTPUT_CHANNELS)\n",
        "model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',\n",
        "              metrics=['accuracy'])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Tc3MiEO2twLS"
      },
      "source": [
        "我们试着运行一下模型，看看它在训练之前给出的预测值。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "UwvIKLZPtxV_",
        "colab": {}
      },
      "source": [
        "def create_mask(pred_mask):\n",
        "  pred_mask = tf.argmax(pred_mask, axis=-1)\n",
        "  pred_mask = pred_mask[..., tf.newaxis]\n",
        "  return pred_mask[0]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "YLNsrynNtx4d",
        "colab": {}
      },
      "source": [
        "def show_predictions(dataset=None, num=1):\n",
        "  if dataset:\n",
        "    for image, mask in dataset.take(num):\n",
        "      pred_mask = model.predict(image)\n",
        "      display([image[0], mask[0], create_mask(pred_mask)])\n",
        "  else:\n",
        "    display([sample_image, sample_mask,\n",
        "             create_mask(model.predict(sample_image[tf.newaxis, ...]))])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "X_1CC0T4dho3",
        "colab": {}
      },
      "source": [
        "show_predictions()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "22AyVYWQdkgk"
      },
      "source": [
        "我们来观察模型是怎样随着训练而改善的。为达成这一目的，下面将定义一个 callback 函数。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wHrHsqijdmL6",
        "colab": {}
      },
      "source": [
        "class DisplayCallback(tf.keras.callbacks.Callback):\n",
        "  def on_epoch_end(self, epoch, logs=None):\n",
        "    clear_output(wait=True)\n",
        "    show_predictions()\n",
        "    print ('\\nSample Prediction after epoch {}\\n'.format(epoch+1))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "StKDH_B9t4SD",
        "colab": {}
      },
      "source": [
        "EPOCHS = 20\n",
        "VAL_SUBSPLITS = 5\n",
        "VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS\n",
        "\n",
        "model_history = model.fit(train_dataset, epochs=EPOCHS,\n",
        "                          steps_per_epoch=STEPS_PER_EPOCH,\n",
        "                          validation_steps=VALIDATION_STEPS,\n",
        "                          validation_data=test_dataset,\n",
        "                          callbacks=[DisplayCallback()])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "P_mu0SAbt40Q",
        "colab": {}
      },
      "source": [
        "loss = model_history.history['loss']\n",
        "val_loss = model_history.history['val_loss']\n",
        "\n",
        "epochs = range(EPOCHS)\n",
        "\n",
        "plt.figure()\n",
        "plt.plot(epochs, loss, 'r', label='Training loss')\n",
        "plt.plot(epochs, val_loss, 'bo', label='Validation loss')\n",
        "plt.title('Training and Validation Loss')\n",
        "plt.xlabel('Epoch')\n",
        "plt.ylabel('Loss Value')\n",
        "plt.ylim([0, 1])\n",
        "plt.legend()\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "unP3cnxo_N72"
      },
      "source": [
        "## 做出预测"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7BVXldSo-0mW"
      },
      "source": [
        "我们来做几个预测。为了节省时间，这里只使用很少的周期（epoch）数，但是你可以设置更多的数量以获得更准确的结果。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ikrzoG24qwf5",
        "colab": {}
      },
      "source": [
        "show_predictions(test_dataset, 3)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "R24tahEqmSCk"
      },
      "source": [
        "## 接下来\n",
        "现在你已经对图像分割是什么以及它的工作原理有所了解。你可以在本教程里尝试使用不同的中间层输出值，或者甚至使用不同的预训练模型。你也可以去 Kaggle 举办的 [Carvana](https://www.kaggle.com/c/carvana-image-masking-challenge/overview) 图像分割挑战赛上挑战自己。\n",
        "\n",
        "你也可以看看 [Tensorflow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection) 上面其他的你可以使用自己数据进行再训练的模型。"
      ]
    }
  ]
}
